Please submit one zip file on cate - CW2.zip containing the following:
Training the GAN will take quite a long time (multiple hours), and so you have four options:
TAs will run a testing cell (at the end of this notebook), so you are required to copy your data transform and denorm functions to a cell near the bottom of the document (it is demarcated). You are advised to check that your implementations pass these tests (in particular, the jit saving and loading may not work for certain niche functions)
You can feel free to add architectural alterations / custom functions outside of pre-defined code blocks, but if you manipulate the model's inputs in some way, please include the same code in the TA test cell, so our tests will run easily.
**The deadline for submission is 19:00, Friday 25th February, 2022**
You will need to install pytorch and import some utilities by running the following cell:
!pip install -q torch torchvision altair seaborn
!git clone -q https://github.com/afspies/icl_dl_cw2_utils
from icl_dl_cw2_utils.utils.plotting import plot_tsne
from pathlib import Path
from tqdm import tqdm
WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv
fatal: destination path 'icl_dl_cw2_utils' already exists and is not an empty directory.
Here we have some default pathing options which vary depending on the environment you are using. You can of course change these as you please.
# Select the execution environment; this decides where outputs and data live.
WORKING_ENV = 'PAPERSPACE' # Can be LABS, COLAB or PAPERSPACE
USERNAME = 'afs219' # If working on Lab Machines - Your college username
assert WORKING_ENV in ['LABS', 'COLAB', 'PAPERSPACE']

if WORKING_ENV == 'COLAB':
    from google.colab import drive
    %load_ext google.colab.data_table
    content_path = '/content/drive/MyDrive/dl_cw_2'
    data_path = './data/'
    drive.mount('/content/drive/') # Outputs will be saved in your google drive
elif WORKING_ENV == 'LABS':
    content_path = '~/Documents/dl_cw_2' # You may want to change this
    # Your python env and training data should be on bitbucket
    data_path = f'/vol/bitbucket/{USERNAME}/dl_cw_data/'
else: # Using Paperspace
    # Paperspace does not properly render animated progress bars
    # Strongly recommend using the JupyterLab UI instead of theirs
    !pip install ipywidgets
    content_path = '/notebooks'
    data_path = './data/'

# Wrap in pathlib.Path so the '/' join operator works in later cells.
content_path = Path(content_path)
Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com
Requirement already satisfied: ipywidgets in /opt/conda/lib/python3.8/site-packages (7.6.5)
Requirement already satisfied: jupyterlab-widgets>=1.0.0 in /opt/conda/lib/python3.8/site-packages (from ipywidgets) (1.0.2)
Requirement already satisfied: traitlets>=4.3.1 in /opt/conda/lib/python3.8/site-packages (from ipywidgets) (5.1.0)
Requirement already satisfied: widgetsnbextension~=3.5.0 in /opt/conda/lib/python3.8/site-packages (from ipywidgets) (3.5.2)
Requirement already satisfied: ipython-genutils~=0.2.0 in /opt/conda/lib/python3.8/site-packages (from ipywidgets) (0.2.0)
Requirement already satisfied: nbformat>=4.2.0 in /opt/conda/lib/python3.8/site-packages (from ipywidgets) (5.1.3)
Requirement already satisfied: ipykernel>=4.5.1 in /opt/conda/lib/python3.8/site-packages (from ipywidgets) (6.4.1)
Requirement already satisfied: ipython>=4.0.0 in /opt/conda/lib/python3.8/site-packages (from ipywidgets) (7.28.0)
Requirement already satisfied: jupyter-client<8.0 in /opt/conda/lib/python3.8/site-packages (from ipykernel>=4.5.1->ipywidgets) (7.0.6)
Requirement already satisfied: matplotlib-inline<0.2.0,>=0.1.0 in /opt/conda/lib/python3.8/site-packages (from ipykernel>=4.5.1->ipywidgets) (0.1.3)
Requirement already satisfied: debugpy<2.0,>=1.0.0 in /opt/conda/lib/python3.8/site-packages (from ipykernel>=4.5.1->ipywidgets) (1.5.0)
Requirement already satisfied: tornado<7.0,>=4.2 in /opt/conda/lib/python3.8/site-packages (from ipykernel>=4.5.1->ipywidgets) (6.1)
Requirement already satisfied: setuptools>=18.5 in /opt/conda/lib/python3.8/site-packages (from ipython>=4.0.0->ipywidgets) (58.2.0)
Requirement already satisfied: prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0 in /opt/conda/lib/python3.8/site-packages (from ipython>=4.0.0->ipywidgets) (3.0.20)
Requirement already satisfied: decorator in /opt/conda/lib/python3.8/site-packages (from ipython>=4.0.0->ipywidgets) (5.1.0)
Requirement already satisfied: backcall in /opt/conda/lib/python3.8/site-packages (from ipython>=4.0.0->ipywidgets) (0.2.0)
Requirement already satisfied: pexpect>4.3 in /opt/conda/lib/python3.8/site-packages (from ipython>=4.0.0->ipywidgets) (4.8.0)
Requirement already satisfied: pickleshare in /opt/conda/lib/python3.8/site-packages (from ipython>=4.0.0->ipywidgets) (0.7.5)
Requirement already satisfied: jedi>=0.16 in /opt/conda/lib/python3.8/site-packages (from ipython>=4.0.0->ipywidgets) (0.18.0)
Requirement already satisfied: pygments in /opt/conda/lib/python3.8/site-packages (from ipython>=4.0.0->ipywidgets) (2.10.0)
Requirement already satisfied: parso<0.9.0,>=0.8.0 in /opt/conda/lib/python3.8/site-packages (from jedi>=0.16->ipython>=4.0.0->ipywidgets) (0.8.2)
Requirement already satisfied: entrypoints in /opt/conda/lib/python3.8/site-packages (from jupyter-client<8.0->ipykernel>=4.5.1->ipywidgets) (0.3)
Requirement already satisfied: pyzmq>=13 in /opt/conda/lib/python3.8/site-packages (from jupyter-client<8.0->ipykernel>=4.5.1->ipywidgets) (22.3.0)
Requirement already satisfied: jupyter-core>=4.6.0 in /opt/conda/lib/python3.8/site-packages (from jupyter-client<8.0->ipykernel>=4.5.1->ipywidgets) (4.8.1)
Requirement already satisfied: python-dateutil>=2.1 in /opt/conda/lib/python3.8/site-packages (from jupyter-client<8.0->ipykernel>=4.5.1->ipywidgets) (2.8.2)
Requirement already satisfied: nest-asyncio>=1.5 in /opt/conda/lib/python3.8/site-packages (from jupyter-client<8.0->ipykernel>=4.5.1->ipywidgets) (1.5.1)
Requirement already satisfied: jsonschema!=2.5.0,>=2.4 in /opt/conda/lib/python3.8/site-packages (from nbformat>=4.2.0->ipywidgets) (4.0.1)
Requirement already satisfied: attrs>=17.4.0 in /opt/conda/lib/python3.8/site-packages (from jsonschema!=2.5.0,>=2.4->nbformat>=4.2.0->ipywidgets) (21.2.0)
Requirement already satisfied: pyrsistent!=0.17.0,!=0.17.1,!=0.17.2,>=0.14.0 in /opt/conda/lib/python3.8/site-packages (from jsonschema!=2.5.0,>=2.4->nbformat>=4.2.0->ipywidgets) (0.18.0)
Requirement already satisfied: ptyprocess>=0.5 in /opt/conda/lib/python3.8/site-packages (from pexpect>4.3->ipython>=4.0.0->ipywidgets) (0.7.0)
Requirement already satisfied: wcwidth in /opt/conda/lib/python3.8/site-packages (from prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0->ipython>=4.0.0->ipywidgets) (0.2.5)
Requirement already satisfied: six>=1.5 in /opt/conda/lib/python3.8/site-packages (from python-dateutil>=2.1->jupyter-client<8.0->ipykernel>=4.5.1->ipywidgets) (1.16.0)
Requirement already satisfied: notebook>=4.4.1 in /opt/conda/lib/python3.8/site-packages (from widgetsnbextension~=3.5.0->ipywidgets) (6.4.1)
Requirement already satisfied: argon2-cffi in /opt/conda/lib/python3.8/site-packages (from notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets) (21.1.0)
Requirement already satisfied: nbconvert in /opt/conda/lib/python3.8/site-packages (from notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets) (6.2.0)
Requirement already satisfied: terminado>=0.8.3 in /opt/conda/lib/python3.8/site-packages (from notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets) (0.12.1)
Requirement already satisfied: Send2Trash>=1.5.0 in /opt/conda/lib/python3.8/site-packages (from notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets) (1.8.0)
Requirement already satisfied: prometheus-client in /opt/conda/lib/python3.8/site-packages (from notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets) (0.11.0)
Requirement already satisfied: jinja2 in /opt/conda/lib/python3.8/site-packages (from notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets) (3.0.1)
Requirement already satisfied: cffi>=1.0.0 in /opt/conda/lib/python3.8/site-packages (from argon2-cffi->notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets) (1.14.6)
Requirement already satisfied: pycparser in /opt/conda/lib/python3.8/site-packages (from cffi>=1.0.0->argon2-cffi->notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets) (2.20)
Requirement already satisfied: MarkupSafe>=2.0 in /opt/conda/lib/python3.8/site-packages (from jinja2->notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets) (2.0.1)
Requirement already satisfied: defusedxml in /opt/conda/lib/python3.8/site-packages (from nbconvert->notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets) (0.7.1)
Requirement already satisfied: nbclient<0.6.0,>=0.5.0 in /opt/conda/lib/python3.8/site-packages (from nbconvert->notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets) (0.5.4)
Requirement already satisfied: pandocfilters>=1.4.1 in /opt/conda/lib/python3.8/site-packages (from nbconvert->notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets) (1.5.0)
Requirement already satisfied: jupyterlab-pygments in /opt/conda/lib/python3.8/site-packages (from nbconvert->notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets) (0.1.2)
Requirement already satisfied: bleach in /opt/conda/lib/python3.8/site-packages (from nbconvert->notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets) (4.1.0)
Requirement already satisfied: testpath in /opt/conda/lib/python3.8/site-packages (from nbconvert->notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets) (0.5.0)
Requirement already satisfied: mistune<2,>=0.8.1 in /opt/conda/lib/python3.8/site-packages (from nbconvert->notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets) (0.8.4)
Requirement already satisfied: packaging in /opt/conda/lib/python3.8/site-packages (from bleach->nbconvert->notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets) (21.0)
Requirement already satisfied: webencodings in /opt/conda/lib/python3.8/site-packages (from bleach->nbconvert->notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets) (0.5.1)
Requirement already satisfied: pyparsing>=2.0.2 in /opt/conda/lib/python3.8/site-packages (from packaging->bleach->nbconvert->notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets) (2.4.7)
WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv
For this coursework, you are asked to implement two commonly used generative models:
For the first part you will use the MNIST dataset (https://en.wikipedia.org/wiki/MNIST_database) and for the second the CIFAR-10 dataset (https://www.cs.toronto.edu/~kriz/cifar.html).
Each part is worth 50 points.
The emphasis of both parts lies in understanding how the models behave and learn, however, some points will be available for getting good results with your GAN (though you should not spend too long on this).
import os
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, sampler
from torchvision import datasets, transforms
from torchvision.utils import save_image, make_grid
import torch.nn.functional as F
import matplotlib.pyplot as plt
def show(img):
    """Display a CHW image tensor with matplotlib (moves it to host memory first)."""
    host_img = img.cpu().numpy()
    # Reorder channels-first (C, H, W) to channels-last (H, W, C) for imshow.
    plt.imshow(host_img.transpose((1, 2, 0)))
# Create the output and data directories. exist_ok=True both simplifies the
# original check-then-create pattern and removes its race condition (the
# directory could appear between os.path.exists and os.makedirs).
os.makedirs(content_path/'CW_VAE/', exist_ok=True)
os.makedirs(data_path, exist_ok=True)
# We set a random seed to ensure that your results are reproducible.
if torch.cuda.is_available():
    torch.backends.cudnn.deterministic = True
torch.manual_seed(0)

GPU = True # Choose whether to use GPU
# Use CUDA only when requested AND actually available; otherwise fall back to CPU.
device = torch.device("cuda" if (GPU and torch.cuda.is_available()) else "cpu")
print(f'Using {device}')
Using cuda
# Necessary Hyperparameters
num_epochs = 10
learning_rate = 0.001
batch_size = 128
latent_dim = 8# Choose a value for the size of the latent space

# Additional Hyperparameters

# (Optionally) Modify transformations on input
# ToTensor scales pixel values into [0, 1], matching the Sigmoid decoder
# output and the BCE reconstruction loss used below.
transform = transforms.Compose([
    transforms.ToTensor(),
])
# (Optionally) Modify the network's output for visualizing your images
def denorm(x):
    """Identity denormalization: ToTensor already yields values in [0, 1]."""
    return x
# Download (if needed) and load the MNIST train/test splits.
train_dat = datasets.MNIST(
    data_path, train=True, download=True, transform=transform
)
test_dat = datasets.MNIST(data_path, train=False, transform=transform)

loader_train = DataLoader(train_dat, batch_size, shuffle=True)
loader_test = DataLoader(test_dat, batch_size, shuffle=False)

# Don't change
# Keep a fixed batch of 32 test images for consistent visual comparison later.
sample_inputs, _ = next(iter(loader_test))
fixed_input = sample_inputs[:32, :, :, :]
save_image(fixed_input, content_path/'CW_VAE/image_original.png')
# *CODE FOR PART 1.1a IN THIS CELL*
class VAE(nn.Module):
    """Convolutional beta-VAE for 28x28 single-channel (MNIST) images.

    Encoder: four valid convolutions (spatial 28 -> 13 -> 10 -> 7 -> 4)
    flattened to h_dim = 256 * 4 * 4 = 4096 features, from which two linear
    heads produce the mean and log-variance of the diagonal Gaussian
    posterior q(z|x).
    Decoder: mirror transposed convolutions back to 28x28 with a final
    Sigmoid, so outputs are per-pixel values in [0, 1] (Bernoulli means,
    matching the BCE reconstruction loss).
    """

    def __init__(self, latent_dim, h_dim=4096, image_channels=1):
        super(VAE, self).__init__()
        #######################################################################
        # ** START OF YOUR CODE **
        #######################################################################
        self.encoder = nn.Sequential(
            nn.Conv2d(image_channels, 32, kernel_size=4, stride=2),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=4, stride=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 256, kernel_size=4, stride=1),
            nn.BatchNorm2d(256),
            nn.ReLU()
        )
        # Posterior parameter heads and the latent -> feature projection.
        self.fc_mu = nn.Linear(h_dim, latent_dim)
        self.fc_logvar = nn.Linear(h_dim, latent_dim)
        self.fc_out_lat = nn.Linear(latent_dim, h_dim)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(h_dim, 128, kernel_size=5, stride=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=5, stride=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=5, stride=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.ConvTranspose2d(32, image_channels, kernel_size=4, stride=2),
            nn.Sigmoid(),
        )
        # Kept for backward compatibility; no longer used for sampling
        # (see the bug-fix note in reparametrize).
        self.epsilon = torch.distributions.Normal(0, 1)
        self.h_dim = h_dim
        #######################################################################
        # ** END OF YOUR CODE **
        #######################################################################

    def encode(self, x):
        """Return a latent sample z plus the posterior mean and log-variance."""
        h = self.encoder(x)
        h_flatten = h.view(h.size(0), -1)
        mu, logvar = self.fc_mu(h_flatten), self.fc_logvar(h_flatten)
        z = self.reparametrize(mu, logvar)
        return z, mu, logvar

    def reparametrize(self, mu, logvar):
        """Sample z = mu + sigma * eps, eps ~ N(0, I) (reparameterization trick)."""
        std = torch.exp(0.5 * logvar)
        # BUG FIX: Normal(0, 1).rsample() draws a SINGLE scalar, which would
        # apply one shared noise value to every latent dimension of every
        # sample in the batch. The noise must be elementwise, i.e. have the
        # same shape as std.
        eps = torch.randn_like(std)
        return mu + std * eps

    def decode(self, z):
        """Map a latent vector back to an image with values in [0, 1]."""
        h = self.fc_out_lat(z)
        # Reshape to a 1x1 spatial map so the transposed convs can upsample it.
        h = h.view(h.size(0), self.h_dim, 1, 1)
        return self.decoder(h)

    def forward(self, x):
        """Return (reconstruction, mu, logvar) for the input batch x."""
        z, mu, logvar = self.encode(x)
        recon = self.decode(z)
        return recon, mu, logvar
# Build the VAE on the selected device and report its size.
model = VAE(latent_dim).to(device)

n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("Total number of parameters is: {}".format(n_params))
print(model)

# optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
Total number of parameters is: 14156881
VAE(
(encoder): Sequential(
(0): Conv2d(1, 32, kernel_size=(4, 4), stride=(2, 2))
(1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
(3): Conv2d(32, 64, kernel_size=(4, 4), stride=(1, 1))
(4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU()
(6): Conv2d(64, 128, kernel_size=(4, 4), stride=(1, 1))
(7): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(8): ReLU()
(9): Conv2d(128, 256, kernel_size=(4, 4), stride=(1, 1))
(10): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(11): ReLU()
)
(fc_mu): Linear(in_features=4096, out_features=8, bias=True)
(fc_logvar): Linear(in_features=4096, out_features=8, bias=True)
(fc_out_lat): Linear(in_features=8, out_features=4096, bias=True)
(decoder): Sequential(
(0): ConvTranspose2d(4096, 128, kernel_size=(5, 5), stride=(1, 1))
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
(3): ConvTranspose2d(128, 64, kernel_size=(5, 5), stride=(1, 1))
(4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU()
(6): ConvTranspose2d(64, 32, kernel_size=(5, 5), stride=(1, 1))
(7): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(8): ReLU()
(9): ConvTranspose2d(32, 1, kernel_size=(4, 4), stride=(2, 2))
(10): Sigmoid()
)
)
Recall the Beta VAE loss, with an encoder $q$ and decoder $p$: $$ \mathcal{L}=\mathbb{E}_{q_\phi(z \mid X)}[\log p_\theta(X \mid z)]-\beta D_{K L}[q_\phi(z \mid X) \| p_\theta(z)]$$
In order to implement this loss you will need to think carefully about your model's outputs and the choice of prior.
There are multiple accepted solutions. Explain your design choices based on the assumptions you make regarding the distribution of your data.
Hint: this refers to the log likelihood as mentioned in the tutorial. Make sure these assumptions reflect on the values of your input data, i.e. depending on your choice you might need to do a simple preprocessing step.
You are encouraged to experiment with the weighting coefficient $\beta$ and observe how it affects your training
# --- Scratch / debugging cell ---
# NOTE(review): these expressions rely on `mu`, `logvar` and `kld` left over
# in the kernel from a previous training run; run fresh they raise NameError.
from tqdm import tqdm
mu.size()
torch.linalg.norm(mu, dim=1, ord=2).size()
logvar.size()
torch.sum(logvar, dim=1).size()
abs(kld)
# Manual KL computation over dim=0 (the batch axis) — the loss function
# below reduces over dim=1 (the latent axis) instead.
mu_L2 = torch.linalg.norm(mu, dim=0, ord=2)
std_L2 = torch.linalg.norm(torch.sqrt(logvar.exp()), dim=0, ord=2)
logvar_dot = torch.sum(logvar, dim=0)
0.5 * torch.sum(mu_L2 + std_L2 - latent_dim - logvar_dot)
#logvar.size()
tensor(-78.4236, device='cuda:0', grad_fn=<MulBackward0>)
# *CODE FOR PART 1.1b IN THIS CELL*
def loss_function_VAE(recon_x, x, mu, logvar, latent_dim, beta):
    #######################################################################
    # ** START OF YOUR CODE **
    #######################################################################
    """Beta-VAE loss: summed BCE reconstruction term + beta * analytic KL.

    recon_x/x are expected in [0, 1] (Bernoulli likelihood, Sigmoid decoder).
    latent_dim is kept for interface compatibility; the constant term is
    obtained by summing 1 over the latent axis instead.
    Returns (total_loss, BCE, KLD), each summed over the batch.
    """
    # BUG FIX: passing both size_average=False and reduction='sum' triggers a
    # deprecation warning and is redundant — reduction='sum' alone suffices.
    BCE = F.binary_cross_entropy(recon_x, x, reduction='sum')
    # Analytic KL( N(mu, sigma^2) || N(0, I) ) per sample:
    #   0.5 * sum_dims( mu^2 + sigma^2 - 1 - log sigma^2 )
    # BUG FIX: the original used unsquared L2 norms of mu and std, i.e.
    # sqrt(sum mu^2) instead of sum mu^2 (and likewise for sigma), which is
    # not the KL divergence and could go negative.
    KLD = 0.5 * torch.sum(mu.pow(2) + logvar.exp() - 1.0 - logvar)
    return BCE + beta * KLD, BCE, KLD
    #######################################################################
    # ** END OF YOUR CODE **
    #######################################################################
# Per-epoch history; each entry is the list of per-batch values for one epoch.
train_losses = []
train_bces = []
train_klds = []
test_losses = []
test_bces = []
test_klds = []

for epoch in range(num_epochs):
    # BUG FIX: re-enter train mode at the start of EVERY epoch. The
    # evaluation loop below switches the model to eval mode, and the
    # original code called model.train() only once before the epoch loop,
    # so epochs 1+ trained with BatchNorm frozen in eval mode.
    model.train()

    # Training loop
    train_loss = []
    train_bce = []
    train_kld = []
    with tqdm(loader_train, unit="batch") as tepoch:
        for batch_idx, (data, _) in enumerate(tepoch):
            #######################################################################
            # ** START OF YOUR CODE **
            #######################################################################
            recon_data, mu, logvar = model(data.to(device))
            # Guard against NaNs so a single bad batch cannot poison training.
            recon_data = torch.nan_to_num(recon_data)
            mu = torch.nan_to_num(mu)
            logvar = torch.nan_to_num(logvar)
            loss, bce, kld = loss_function_VAE(recon_data, data.to(device), mu, logvar, latent_dim, beta=0.5)

            # Save per-sample training losses for later plotting.
            train_loss.append(loss.item()/len(data))
            train_bce.append(bce.item()/len(data))
            train_kld.append(kld.item()/len(data))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            #######################################################################
            # ** END OF YOUR CODE **
            #######################################################################
            if batch_idx % 20 == 0:
                tepoch.set_description(f"Epoch {epoch}")
                tepoch.set_postfix(loss=loss.item()/len(data))

    # Testing loop: eval mode, and no autograd bookkeeping (no backward here).
    model.eval()
    test_loss = []
    test_bce = []
    test_kld = []
    with torch.no_grad():
        with tqdm(loader_test, unit="batch") as tepoch:
            for batch_idx, (data, _) in enumerate(tepoch):
                #######################################################################
                # ** START OF YOUR CODE **
                #######################################################################
                recon_data, mu, logvar = model(data.to(device))
                recon_data = torch.nan_to_num(recon_data)
                mu = torch.nan_to_num(mu)
                logvar = torch.nan_to_num(logvar)
                loss, bce, kld = loss_function_VAE(recon_data, data.to(device), mu, logvar, latent_dim, beta=0.5)

                # Save per-sample testing losses.
                test_loss.append(loss.item()/len(data))
                test_bce.append(bce.item()/len(data))
                test_kld.append(kld.item()/len(data))
                #######################################################################
                # ** END OF YOUR CODE **
                #######################################################################
                if batch_idx % 20 == 0:
                    tepoch.set_description(f"Epoch {epoch}")
                    tepoch.set_postfix(loss=loss.item()/len(data))

    train_losses.append(train_loss)
    train_bces.append(train_bce)
    train_klds.append(train_kld)
    test_losses.append(test_loss)
    test_bces.append(test_bce)
    test_klds.append(test_kld)

    # Save a traced copy of the model after the final epoch. The last test
    # batch `data` is used only to give the tracer a correctly-shaped input.
    if epoch == num_epochs - 1:
        with torch.no_grad():
            torch.jit.save(torch.jit.trace(model, (data.to(device)), check_trace=False),
                           content_path/'CW_VAE/VAE_model.pth')
Epoch 0: 100%|██████████| 469/469 [00:20<00:00, 23.23batch/s, loss=99.8] Epoch 0: 100%|██████████| 79/79 [00:02<00:00, 29.48batch/s, loss=98] Epoch 1: 100%|██████████| 469/469 [00:19<00:00, 23.47batch/s, loss=88.1] Epoch 1: 100%|██████████| 79/79 [00:02<00:00, 29.63batch/s, loss=92.2] Epoch 2: 100%|██████████| 469/469 [00:20<00:00, 23.25batch/s, loss=92] Epoch 2: 100%|██████████| 79/79 [00:02<00:00, 29.49batch/s, loss=93.2] Epoch 3: 100%|██████████| 469/469 [00:20<00:00, 23.29batch/s, loss=87] Epoch 3: 100%|██████████| 79/79 [00:02<00:00, 29.86batch/s, loss=84.8] Epoch 4: 100%|██████████| 469/469 [00:20<00:00, 23.28batch/s, loss=86.2] Epoch 4: 100%|██████████| 79/79 [00:02<00:00, 29.29batch/s, loss=87.2] Epoch 5: 100%|██████████| 469/469 [00:20<00:00, 23.09batch/s, loss=81.4] Epoch 5: 100%|██████████| 79/79 [00:02<00:00, 29.41batch/s, loss=82] Epoch 6: 100%|██████████| 469/469 [00:19<00:00, 23.47batch/s, loss=80.1] Epoch 6: 100%|██████████| 79/79 [00:02<00:00, 29.16batch/s, loss=81.1] Epoch 7: 100%|██████████| 469/469 [00:19<00:00, 23.50batch/s, loss=82.8] Epoch 7: 100%|██████████| 79/79 [00:02<00:00, 29.72batch/s, loss=81.5] Epoch 8: 100%|██████████| 469/469 [00:20<00:00, 23.18batch/s, loss=79.1] Epoch 8: 100%|██████████| 79/79 [00:02<00:00, 29.54batch/s, loss=80.9] Epoch 9: 100%|██████████| 469/469 [00:19<00:00, 23.54batch/s, loss=77.8] Epoch 9: 100%|██████████| 79/79 [00:02<00:00, 29.67batch/s, loss=80.3]
# Show the first image of the last test batch next to its reconstruction.
plt.figure()
plt.gray()
original_ax = plt.subplot(1, 2, 1)
original_ax.imshow(data[0, 0, :, :].numpy())
recon_ax = plt.subplot(1, 2, 2)
recon_ax.imshow(recon_data[0, 0, :, :].cpu().detach().numpy())
<matplotlib.image.AxesImage at 0x7fc04c701a90>
Explain your choice of loss and how this relates to:
# Any code for your explanation here
YOUR ANSWER
The loss contains two terms: the first is the BCE (reconstruction term) between the reconstructed data and the original data, and the second is the KLD (regularisation term), which measures how far the latent-space distribution is from the VAE's prior. The second term ensures that the latent space is well organised; this helps avoid overfitting and gives the VAE three main properties: continuity (two nearby samples in the latent space decode to similar outputs), completeness (every sample drawn from the latent distribution decodes to reasonably meaningful content), and a disentangled latent space (each latent sample decodes to a single, unique reconstructed image).
a. Plot your loss curves
b. Show reconstructions and samples
c. Discuss your results from parts (a) and (b)
Plot your loss curves (6 in total, 3 for the training set and 3 for the test set): total loss, reconstruction log likelihood loss, KL loss (x-axis: epochs, y-axis: loss). If you experimented with different values of $\beta$, you may wish to display multiple plots (worth 1 point).
# Average the per-batch losses within each epoch -> one point per epoch.
train_losses_avg = np.mean(np.array(train_losses), axis=1)
train_bces_avg = np.mean(np.array(train_bces), axis=1)
train_klds_avg = np.mean(np.array(train_klds), axis=1)
test_losses_avg = np.mean(np.array(test_losses), axis=1)
test_bces_avg = np.mean(np.array(test_bces), axis=1)
test_klds_avg = np.mean(np.array(test_klds), axis=1)

plt.figure(figsize=(30, 20))

ax1 = plt.subplot(3, 1, 1)
ax1.plot(train_losses_avg, label="Train Total Loss")
ax1.plot(test_losses_avg, label="Test Total Loss")
ax1.set_xlabel('Epoch number', fontsize=15)
ax1.set_ylabel('Loss value', fontsize=15)
# BUG FIX: the fontsize keyword had been garbled into the title string
# ("Total , fontsize=15Loss"); pass it as a keyword argument instead.
ax1.set_title("Total Loss", fontsize=15)
ax1.legend(loc="upper right", prop={"size":15})

ax2 = plt.subplot(3, 1, 2)
ax2.plot(train_klds_avg, label="Train KLD")
ax2.plot(test_klds_avg, label="Test KLD")
ax2.set_xlabel('Epoch number', fontsize=15)
ax2.set_ylabel('Loss value', fontsize=15)
ax2.set_title("KLD", fontsize=15)
ax2.legend(loc="upper right", prop={"size":15})

ax3 = plt.subplot(3, 1, 3)
ax3.plot(train_bces_avg, label="Train BCE")
ax3.plot(test_bces_avg, label="Test BCE")
ax3.set_xlabel('Epoch number', fontsize=15)
ax3.set_ylabel('Loss value', fontsize=15)
ax3.set_title("BCE", fontsize=15)
ax3.legend(loc="upper right", prop={"size":15})
# *PLOT LOSS FOR DIFFERENT BETAS*
# Train a fresh model per beta and record loss histories keyed by beta.
all_beta_train_losses = {}
all_beta_train_bces = {}
all_beta_train_klds = {}
all_beta_test_losses = {}
all_beta_test_bces = {}
all_beta_test_klds = {}
betas = [0, 0.5, 1]

for beta in betas:
    print(f"beta={beta}")
    beta_model = VAE(latent_dim).to(device)
    beta_optimizer = torch.optim.Adam(beta_model.parameters(), lr=learning_rate)

    beta_train_losses = []
    beta_train_bces = []
    beta_train_klds = []
    beta_test_losses = []
    beta_test_bces = []
    beta_test_klds = []

    for epoch in range(num_epochs):
        # BUG FIX: re-enter train mode every epoch — the evaluation loop
        # below switches to eval mode, and the original called
        # beta_model.train() only once before the epoch loop.
        beta_model.train()

        # Training loop
        train_loss = []
        train_bce = []
        train_kld = []
        with tqdm(loader_train, unit="batch") as tepoch:
            for batch_idx, (data, _) in enumerate(tepoch):
                #######################################################################
                # ** START OF YOUR CODE **
                #######################################################################
                recon_data, mu, logvar = beta_model(data.to(device))
                # Guard against NaNs so one bad batch cannot poison training.
                recon_data = torch.nan_to_num(recon_data)
                mu = torch.nan_to_num(mu)
                logvar = torch.nan_to_num(logvar)
                loss, bce, kld = loss_function_VAE(recon_data, data.to(device), mu, logvar, latent_dim, beta=beta)

                # Save per-sample training losses.
                train_loss.append(loss.item()/len(data))
                train_bce.append(bce.item()/len(data))
                train_kld.append(kld.item()/len(data))

                beta_optimizer.zero_grad()
                loss.backward()
                beta_optimizer.step()
                #######################################################################
                # ** END OF YOUR CODE **
                #######################################################################
                if batch_idx % 20 == 0:
                    tepoch.set_description(f"Epoch {epoch}")
                    tepoch.set_postfix(loss=loss.item()/len(data))

        # Testing loop: eval mode and no autograd bookkeeping.
        beta_model.eval()
        test_loss = []
        test_bce = []
        test_kld = []
        with torch.no_grad():
            with tqdm(loader_test, unit="batch") as tepoch:
                for batch_idx, (data, _) in enumerate(tepoch):
                    #######################################################################
                    # ** START OF YOUR CODE **
                    #######################################################################
                    recon_data, mu, logvar = beta_model(data.to(device))
                    recon_data = torch.nan_to_num(recon_data)
                    mu = torch.nan_to_num(mu)
                    logvar = torch.nan_to_num(logvar)
                    loss, bce, kld = loss_function_VAE(recon_data, data.to(device), mu, logvar, latent_dim, beta=beta)

                    # Save per-sample testing losses.
                    test_loss.append(loss.item()/len(data))
                    test_bce.append(bce.item()/len(data))
                    test_kld.append(kld.item()/len(data))
                    #######################################################################
                    # ** END OF YOUR CODE **
                    #######################################################################
                    if batch_idx % 20 == 0:
                        tepoch.set_description(f"Epoch {epoch}")
                        tepoch.set_postfix(loss=loss.item()/len(data))

        beta_train_losses.append(train_loss)
        beta_train_bces.append(train_bce)
        beta_train_klds.append(train_kld)
        beta_test_losses.append(test_loss)
        beta_test_bces.append(test_bce)
        beta_test_klds.append(test_kld)

    all_beta_train_losses[beta] = beta_train_losses
    all_beta_train_bces[beta] = beta_train_bces
    all_beta_train_klds[beta] = beta_train_klds
    all_beta_test_losses[beta] = beta_test_losses
    all_beta_test_bces[beta] = beta_test_bces
    all_beta_test_klds[beta] = beta_test_klds
beta=0
0%| | 0/469 [00:00<?, ?batch/s]/opt/conda/lib/python3.8/site-packages/torch/nn/_reduction.py:42: UserWarning: size_average and reduce args will be deprecated, please use reduction='sum' instead. warnings.warn(warning.format(ret)) Epoch 0: 100%|██████████| 469/469 [00:20<00:00, 22.99batch/s, loss=101] Epoch 0: 100%|██████████| 79/79 [00:02<00:00, 29.45batch/s, loss=104] Epoch 1: 100%|██████████| 469/469 [00:20<00:00, 23.43batch/s, loss=87.7] Epoch 1: 100%|██████████| 79/79 [00:02<00:00, 29.19batch/s, loss=86.8] Epoch 2: 100%|██████████| 469/469 [00:20<00:00, 23.29batch/s, loss=85.1] Epoch 2: 100%|██████████| 79/79 [00:02<00:00, 29.58batch/s, loss=83.3] Epoch 3: 100%|██████████| 469/469 [00:20<00:00, 23.14batch/s, loss=81.5] Epoch 3: 100%|██████████| 79/79 [00:02<00:00, 29.32batch/s, loss=82.5] Epoch 4: 100%|██████████| 469/469 [00:20<00:00, 23.14batch/s, loss=81] Epoch 4: 100%|██████████| 79/79 [00:02<00:00, 29.33batch/s, loss=81] Epoch 5: 100%|██████████| 469/469 [00:20<00:00, 23.14batch/s, loss=79.7] Epoch 5: 100%|██████████| 79/79 [00:02<00:00, 28.80batch/s, loss=79] Epoch 6: 100%|██████████| 469/469 [00:20<00:00, 23.11batch/s, loss=77.6] Epoch 6: 100%|██████████| 79/79 [00:02<00:00, 29.35batch/s, loss=79.6] Epoch 7: 100%|██████████| 469/469 [00:20<00:00, 23.23batch/s, loss=78.8] Epoch 7: 100%|██████████| 79/79 [00:02<00:00, 29.75batch/s, loss=78.4] Epoch 8: 100%|██████████| 469/469 [00:19<00:00, 23.59batch/s, loss=83.1] Epoch 8: 100%|██████████| 79/79 [00:02<00:00, 29.54batch/s, loss=78.2] Epoch 9: 100%|██████████| 469/469 [00:19<00:00, 23.54batch/s, loss=77.4] Epoch 9: 100%|██████████| 79/79 [00:02<00:00, 28.99batch/s, loss=78]
beta=0.5
Epoch 0: 100%|██████████| 469/469 [00:20<00:00, 23.03batch/s, loss=104] Epoch 0: 100%|██████████| 79/79 [00:02<00:00, 29.49batch/s, loss=128] Epoch 1: 100%|██████████| 469/469 [00:20<00:00, 23.06batch/s, loss=97.2] Epoch 1: 100%|██████████| 79/79 [00:02<00:00, 29.01batch/s, loss=99.6] Epoch 2: 100%|██████████| 469/469 [00:20<00:00, 23.28batch/s, loss=84.2] Epoch 2: 100%|██████████| 79/79 [00:02<00:00, 28.78batch/s, loss=88] Epoch 3: 100%|██████████| 469/469 [00:20<00:00, 23.19batch/s, loss=83.2] Epoch 3: 100%|██████████| 79/79 [00:02<00:00, 29.03batch/s, loss=91.4] Epoch 4: 100%|██████████| 469/469 [00:20<00:00, 22.98batch/s, loss=79.8] Epoch 4: 100%|██████████| 79/79 [00:02<00:00, 28.82batch/s, loss=85.4] Epoch 5: 100%|██████████| 469/469 [00:20<00:00, 23.23batch/s, loss=85.7] Epoch 5: 100%|██████████| 79/79 [00:02<00:00, 29.02batch/s, loss=86.9] Epoch 6: 100%|██████████| 469/469 [00:20<00:00, 23.21batch/s, loss=86.9] Epoch 6: 100%|██████████| 79/79 [00:02<00:00, 29.88batch/s, loss=83.2] Epoch 7: 100%|██████████| 469/469 [00:20<00:00, 23.11batch/s, loss=84.9] Epoch 7: 100%|██████████| 79/79 [00:02<00:00, 29.32batch/s, loss=82.6] Epoch 8: 100%|██████████| 469/469 [00:20<00:00, 23.09batch/s, loss=82.1] Epoch 8: 100%|██████████| 79/79 [00:02<00:00, 29.14batch/s, loss=82.9] Epoch 9: 100%|██████████| 469/469 [00:20<00:00, 23.04batch/s, loss=79.9] Epoch 9: 100%|██████████| 79/79 [00:02<00:00, 29.07batch/s, loss=80.9]
beta=1
Epoch 0: 100%|██████████| 469/469 [00:20<00:00, 23.08batch/s, loss=110] Epoch 0: 100%|██████████| 79/79 [00:02<00:00, 29.23batch/s, loss=126] Epoch 1: 100%|██████████| 469/469 [00:20<00:00, 23.11batch/s, loss=128] Epoch 1: 100%|██████████| 79/79 [00:02<00:00, 29.07batch/s, loss=111] Epoch 2: 100%|██████████| 469/469 [00:20<00:00, 22.98batch/s, loss=79.2] Epoch 2: 100%|██████████| 79/79 [00:02<00:00, 29.16batch/s, loss=84.7] Epoch 3: 100%|██████████| 469/469 [00:20<00:00, 23.21batch/s, loss=82.9] Epoch 3: 100%|██████████| 79/79 [00:02<00:00, 29.81batch/s, loss=85.8] Epoch 4: 100%|██████████| 469/469 [00:20<00:00, 23.21batch/s, loss=81.2] Epoch 4: 100%|██████████| 79/79 [00:02<00:00, 29.11batch/s, loss=83.7] Epoch 5: 100%|██████████| 469/469 [00:20<00:00, 23.06batch/s, loss=77.4] Epoch 5: 100%|██████████| 79/79 [00:02<00:00, 28.96batch/s, loss=82.5] Epoch 6: 100%|██████████| 469/469 [00:20<00:00, 23.22batch/s, loss=82.3] Epoch 6: 100%|██████████| 79/79 [00:02<00:00, 29.07batch/s, loss=79.1] Epoch 7: 100%|██████████| 469/469 [00:20<00:00, 23.14batch/s, loss=77.6] Epoch 7: 100%|██████████| 79/79 [00:02<00:00, 29.33batch/s, loss=76.6] Epoch 8: 100%|██████████| 469/469 [00:20<00:00, 23.14batch/s, loss=74.3] Epoch 8: 100%|██████████| 79/79 [00:02<00:00, 28.87batch/s, loss=95.9] Epoch 9: 100%|██████████| 469/469 [00:20<00:00, 23.15batch/s, loss=77.9] Epoch 9: 100%|██████████| 79/79 [00:02<00:00, 29.27batch/s, loss=77]
# *CODE FOR PART 1.2a IN THIS CELL*
def plot_beta_curves(all_train_losses,
all_train_bces,
all_train_klds,
all_test_losses,
all_test_bces,
all_test_klds,
beta=0.5):
# Format all arrays
train_losses_avg = np.mean(np.array(all_train_losses[beta]), axis=1)
train_bces_avg = np.mean(np.array(all_train_bces[beta]), axis=1)
train_klds_avg = np.mean(np.array(all_train_klds[beta]), axis=1)
test_losses_avg = np.mean(np.array(all_test_losses[beta]), axis=1)
test_bces_avg = np.mean(np.array(all_test_bces[beta]), axis=1)
test_klds_avg = np.mean(np.array(all_test_klds[beta]), axis=1)
plt.figure(figsize=(30, 20))
plt.plot(train_losses_avg, label="Train Loss = BCE - KLD")
plt.plot(test_losses_avg, label="Test Loss = BCE - KLD")
plt.plot(-train_klds_avg, label="Train KLD")
plt.plot(-test_klds_avg, label="Test KLD")
plt.plot(train_bces_avg, label="Train BCE")
plt.plot(test_bces_avg, label="Test BCE")
plt.xlabel('Epoch number', fontsize=15)
plt.ylabel('Loss value', fontsize=15)
plt.title(f"Loss curves for beta={beta}", fontsize=15)
plt.legend(loc="upper right", prop={"size":15})
# Draw one loss-curve figure per regularisation strength.
for beta_value in (0, 0.5, 1):
    plot_beta_curves(all_beta_train_losses,
                     all_beta_train_bces,
                     all_beta_train_klds,
                     all_beta_test_losses,
                     all_beta_test_bces,
                     all_beta_test_klds,
                     beta=beta_value)
YOUR ANSWER
We can see that the higher the beta, the better the regularization, but as the KL value is really low compared to the BCE value, it doesn't have a huge impact. One solution to cope with this could be to normalize the BCE so that both values are in the same order of magnitude.
Visualize a subset of the images of the test set and their reconstructions as well as a few generated samples. Most of the code for this part is provided. You only need to call the forward pass of the model for the given inputs (might vary depending on your implementation).
For reference, here's some samples from our VAE.
# *CODE FOR PART 1.2b IN THIS CELL*
# load the model
print('Input images')
print('-'*50)

# Grab one test batch and keep the first 32 images as a fixed input.
sample_inputs, _ = next(iter(loader_test))
fixed_input = sample_inputs[0:32, :, :, :]

# visualize the original images of the last batch of the test set.
# FIX: the deprecated `range=None` kwarg was dropped from make_grid — it only
# triggered a torchvision UserWarning and None was already the default.
img = make_grid(denorm(fixed_input), nrow=8, padding=2, normalize=False,
                scale_each=False, pad_value=0)
plt.figure()
show(img)

print('Reconstructed images')
print('-'*50)
with torch.no_grad():
    # visualize the reconstructed images of the last batch of test set
    #######################################################################
    #                       ** START OF YOUR CODE **
    #######################################################################
    recon_batch, _, _ = model(fixed_input.to(device))
    #######################################################################
    #                        ** END OF YOUR CODE **
    #######################################################################
    recon_batch = recon_batch.cpu()
    recon_batch = make_grid(denorm(recon_batch), nrow=8, padding=2, normalize=False,
                            scale_each=False, pad_value=0)
    plt.figure()
    show(recon_batch)

print('Generated Images')
print('-'*50)
model.eval()
n_samples = 256
# Sample latent codes from the standard-normal prior and decode them.
z = torch.randn(n_samples, latent_dim).to(device)
with torch.no_grad():
    #######################################################################
    #                       ** START OF YOUR CODE **
    #######################################################################
    samples = model.decode(z)
    #######################################################################
    #                        ** END OF YOUR CODE **
    #######################################################################
    samples = samples.cpu()
    samples = make_grid(denorm(samples), nrow=16, padding=2, normalize=False,
                        scale_each=False, pad_value=0)
    plt.figure(figsize = (8,8))
    show(samples)
Input images -------------------------------------------------- Reconstructed images -------------------------------------------------- Generated Images --------------------------------------------------
/opt/conda/lib/python3.8/site-packages/torchvision/utils.py:50: UserWarning: range will be deprecated, please use value_range instead. warnings.warn(warning)
Provide a brief analysis of your loss curves and reconstructions:
YOUR ANSWER
Qualitative analysis of the learned representations
In this question you are asked to qualitatively assess the representations that your model has learned. In particular:
a. Dimensionality Reduction of learned embeddings
b. Interpolating in the latent space
Extract the latent representations of the test set and visualize them using T-SNE (see implementation). You can use a T-SNE implementation from a library such as scikit-learn.
We've provided a function to visualize a subset of the data, but you are encouraged to also produce a matplotlib plot (please use different colours for each digit class).
# *CODE FOR PART 1.3a IN THIS CELL
from sklearn.manifold import TSNE

# Interactive Visualization - Code Provided
test_dataloader = DataLoader(test_dat, 10000, shuffle=False)
""" Inputs to the function are
z_embedded - Embedded X, Y positions for every point in test_dataloader
test_dataloader - dataloader with batchsize set to 10000
num_points - number of points plotted (will slow down with >1k)
"""
# Encode the full test set once; t-SNE then runs on these latent codes.
test_sample_inputs, _ = next(iter(test_dataloader))
z, _, _ = model.encode(test_sample_inputs.to(device))
z = z.cpu()

# Custom Visualizations
perplexities = [5, 30, 50]
perplexity = perplexities[0]
print(f"perplexity={perplexity}")
tsne = TSNE(n_components=2, learning_rate='auto',
            perplexity=perplexity, init='random')
z_embedded = tsne.fit_transform(z.detach().numpy())
plot_tsne(z_embedded, test_dataloader, num_points=1000, darkmode=False)
perplexity=5
# Repeat the embedding with the mid-range perplexity for comparison.
perplexity = perplexities[1]
print(f"perplexity={perplexity}")
tsne = TSNE(n_components=2, learning_rate='auto',
            perplexity=perplexity, init='random')
z_embedded = tsne.fit_transform(z.detach().numpy())
plot_tsne(z_embedded, test_dataloader, num_points=1000, darkmode=False)
perplexity=30
# Repeat the embedding once more with the largest perplexity.
perplexity = perplexities[2]
print(f"perplexity={perplexity}")
tsne = TSNE(n_components=2, learning_rate='auto',
            perplexity=perplexity, init='random')
z_embedded = tsne.fit_transform(z.detach().numpy())
plot_tsne(z_embedded, test_dataloader, num_points=1000, darkmode=False)
perplexity=50
What do you observe? Discuss the structure of the visualized representations.
Note - If you created multiple plots and want to include them in your discussion, the best option is to upload them to (e.g.) google drive and then embed them via a public share link. If you reference local files, please include these in your submission zip, and use relative pathing if you are embedding them (with the notebook in the base directory)
YOUR ANSWER
We can observe that there are still a few outliers in all the clusters. Actually the KL term should have helped to encode them correctly, but if we look closer at these outliers we can clearly see that even a human would have difficulties reading them. Thus they are either wrongly labeled or genuinely ambiguous digits. Moreover, we can see that the higher the perplexity, the more the clusters are separated from each other. However, in the plot with perplexity=5 we can already identify the boundaries, so there is no need to increase the perplexity further, as it already shows the completeness and continuity of the VAE model we trained. Last but not least, it's important to keep in mind that we are projecting vectors from 8 down to 2 dimensions, so there are many possible configurations, and the plot doesn't represent how the data points are organized in the original 8 dimensions — which is what we would need to fully understand the interpolation right after. Finally, I would say that t-SNE is a good tool, as we can only plot in at most 3 dimensions, and it looks quite reliable.
Perform a linear interpolation in the latent space of the autoencoder by choosing any two digits from the test set. What do you observe regarding the transition from one digit to the other?
# CODE FOR PART 1.3b IN THIS CELL
index_1 = 35
index_2 = 45

def encode_image(model, loader, index):
    """Encode a single image into the VAE latent space.

    model  - the trained VAE (must expose .encode returning (z, mu, logvar)).
    loader - a batched image tensor to index into (e.g. test_sample_inputs).
    index  - which image of `loader` to encode.
    Returns the latent code on the CPU.

    BUGFIX: the original ignored the `loader` argument and always read the
    global `test_sample_inputs`; now the tensor passed by the caller is used
    (callers already pass test_sample_inputs, so behavior is unchanged here).
    """
    image = loader[index:index+1, :, :, :]
    z, _, _ = model.encode(image.to(device))
    return z.cpu()

z_1 = encode_image(model, test_sample_inputs, index_1)
z_2 = encode_image(model, test_sample_inputs, index_2)

# Linear interpolation: ratio sweeps 0 -> 1 in steps of 0.2 (6 points),
# moving from z_2 (ratio=0) to z_1 (ratio=1).
ratio_range = np.arange(0, 1.1, 0.2)
z_batch = torch.zeros(len(ratio_range), latent_dim)
for index, ratio in enumerate(ratio_range):
    z = ratio*z_1 + (1-ratio)*z_2
    z_batch[index, :] = z

# Decode all interpolated codes in one batch and plot them side by side.
samples = model.decode(z_batch.to(device))
samples = samples.cpu()
plt.figure(figsize=(20, 10))
plt.gray()
for index in range(len(ratio_range)):
    ax = plt.subplot(1, len(ratio_range), index+1)
    ax.imshow(samples[index,0,:,:].detach().numpy())
What did you observe in the interpolation? Is this what you expected?
YOUR ANSWER In my interpolation, I observe that the 3 is actually somehow in between the 5 and the 2 in the latent space. In the first steps the upper right corner of the 5 is being rounded, which makes it look like a 3. Then it starts to break the lines in the upper left and the lower right corners. It shows how continuous the method is. Concerning the relation with the t-SNE visualization, it appears that sometimes the t-SNE representation locates the 3 in between the 5 and the 2, or close to each of them, which is the case above. But as it is in 2 dimensions, we cannot always see whether the 3 would sit between the 5 and the 2 in the full latent_dim-dimensional space.
In this task, your main objective is to train a DCGAN (https://arxiv.org/abs/1511.06434) on the CIFAR-10 dataset. You should experiment with different architectures and tricks for stability in training (such as using different activation functions, batch normalization, different values for the hyper-parameters, etc.). In the end, you should provide us with:
Your Task:
a. Implement the DCGAN architecture.
b. Define a loss and implement the Training Loop
c. Visualize images sampled from your best model's generator ("Extension" Assessed on quality)
d. Discuss the experimentations which led to your final architecture. You can plot losses or generated results by other architectures that you tested to back your arguments (but this is not necessary to get full marks).
Clarification: You should not be worrying too much about getting an "optimal" performance on your trained GAN. We want you to demonstrate to us that you experimented with different types of DCGAN variations, report what difficulties transpired throughout the training process, etc. In other words, if we see that you provided us with a running implementation, that you detail different experimentations that you did before providing us with your best one, and that you have grasped the concepts, you can still get good marks. The attached model does not have to be perfect, and the extension marks for performance are only worth 10 points.
import os
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.data import sampler
from torchvision import datasets, transforms
from torchvision.utils import save_image, make_grid
from torch.optim.lr_scheduler import StepLR, MultiStepLR
import torch.nn.functional as F
import matplotlib.pyplot as plt
# CIFAR-10 per-channel statistics used by the normalising transform below.
mean = torch.Tensor([0.4914, 0.4822, 0.4465])
std = torch.Tensor([0.247, 0.243, 0.261])
# Inverse of transforms.Normalize(mean, std): applies (x - (-mean/std)) / (1/std),
# i.e. x * std + mean per channel.
unnormalize = transforms.Normalize((-mean / std).tolist(), (1.0 / std).tolist())

def denorm(x, channels=None, w=None, h=None, resize=False):
    """Undo the CIFAR-10 normalisation and optionally reshape to image form.

    x        - normalised tensor (N, C, H, W), or flat when resize=True.
    channels, w, h - target image shape; required only when resize=True.
    resize   - when True, view the result as (N, channels, w, h).
    Returns the de-normalised tensor.
    Raises ValueError if resize=True but channels/w/h is missing.
    (FIX: the original printed a message and then crashed inside `.view`
    with a confusing TypeError; now the error is raised explicitly.)
    """
    x = unnormalize(x)
    if resize:
        if channels is None or w is None or h is None:
            raise ValueError('Number of channels, width and height must be provided for resize.')
        x = x.view(x.size(0), channels, w, h)
    return x
def show(img):
    """Display a (C, H, W) image tensor with matplotlib (CPU, HWC order)."""
    plt.imshow(np.transpose(img.cpu().numpy(), (1, 2, 0)))
# Create the GAN output folders (no-ops when they already exist).
# os.path.join accepts both str and pathlib.Path for content_path, unlike
# the original `content_path/'CW_GAN'` which requires a Path; makedirs with
# exist_ok=True also removes the check-then-create race.
gan_dir = os.path.join(content_path, 'CW_GAN')
os.makedirs(gan_dir, exist_ok=True)
os.makedirs(os.path.join(gan_dir, 'learning_curves'), exist_ok=True)
GPU = True  # Choose whether to use GPU
# Fall back to the CPU either when the GPU is disabled or unavailable.
if GPU and torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f'Using {device}')
# We set a random seed to ensure that your results are reproducible.
if torch.cuda.is_available():
    # cuDNN autotuning can pick non-deterministic kernels; force the
    # deterministic implementations so GPU runs are repeatable.
    torch.backends.cudnn.deterministic = True
torch.manual_seed(0)
Using cuda
<torch._C.Generator at 0x7fc170234bd0>
Fill in the missing parts in the cells below in order to complete the Generator and Discriminator classes. You will need to define:
the decoder (generator) and the discriminator. Recommendations for experimentation:
Some general recommendations:
Use upsampling (torch.nn.Upsample) or transposed convolutions (torch.nn.ConvTranspose2d). Use batch normalization (torch.nn.BatchNorm2d) and leaky ReLU (torch.nn.LeakyReLU) units after each convolutional layer. Try to follow the common practices for CNNs (e.g. small kernels, max pooling, ReLU activations), in order to narrow down your possible choices.
**Your model should not have more than 25 Million Parameters**
The number of epochs that will be needed in order to train the network will vary depending on your choices. As an advice, we recommend that while experimenting you should allow around 20 epochs and if the loss doesn't sufficiently drop, restart the training with a more powerful architecture. You don't need to train the network to an extreme if you don't have the time.
batch_size = 64 # change that

# Normalise CIFAR-10 images with the per-channel statistics defined above.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std),
])

# note - data_path was initialized at the top of the notebook
cifar10_train = datasets.CIFAR10(data_path, train=True, download=True, transform=transform)
cifar10_test = datasets.CIFAR10(data_path, train=False, download=True, transform=transform)

# FIX: shuffle the training loader so every epoch sees the images in a fresh
# order (the original iterated the dataset in the same fixed order each epoch,
# which can bias GAN training). The test loader stays deterministic.
loader_train = DataLoader(cifar10_train, batch_size=batch_size, shuffle=True)
loader_test = DataLoader(cifar10_test, batch_size=batch_size)
Files already downloaded and verified Files already downloaded and verified
We'll visualize a subset of the test set:
# Visualise a batch of test images.
# FIX: the deprecated `range=None` kwarg was dropped from make_grid — it only
# triggered a torchvision UserWarning and None was already the default.
samples, _ = next(iter(loader_test))
samples = samples.cpu()
samples = make_grid(denorm(samples), nrow=8, padding=2, normalize=False,
                    scale_each=False, pad_value=0)
plt.figure(figsize = (15,15))
plt.axis('off')
show(samples)
/opt/conda/lib/python3.8/site-packages/torchvision/utils.py:50: UserWarning: range will be deprecated, please use value_range instead. warnings.warn(warning)
# Confirm the tensor layout of one batch: (batch, channels, height, width).
samples = next(iter(loader_test))[0]
samples.shape
torch.Size([64, 3, 32, 32])
Define hyperparameters and the model
# *CODE FOR PART 2.1 IN THIS CELL*
# Choose the number of epochs, the learning rate
# and the size of the Generator's input noise vector.
# Number of training epochs
num_epochs = 50
# Learning rate for optimizers (shared by G and D)
learning_rate = 0.0002
# Size of z latent vector (i.e. size of generator input)
latent_vector_size = 100
### Other hyperparams ###
# Spatial size of training images.
# All images will be resized to this size using a transformer.
image_size = 32
# Number of channels in the training images. For color images this is 3
nc = 3
# Size of feature maps in generator (base channel width; layers use multiples)
ngf = 128
# Size of feature maps in discriminator (base channel width; layers use multiples)
ndf = 128
# Beta1 hyperparam for Adam optimizers
beta1 = 0.5
# Sanity-check the generator's upsampling path: a 1x1 latent map should
# end up as a (nc x 32 x 32) image after the five transposed convolutions.
latent_test = torch.randn(batch_size, latent_vector_size, 1, 1)
gen_channels = [latent_vector_size, ngf * 4, ngf * 2, ngf, ngf // 2, nc]
last_latent = latent_test
for depth, (c_in, c_out) in enumerate(zip(gen_channels[:-1], gen_channels[1:])):
    # First layer uses stride 1 (1x1 -> 2x2); the rest double the spatial size.
    stride = 1 if depth == 0 else 2
    last_latent = nn.ConvTranspose2d(c_in, c_out, 4, stride, 1, bias=False)(last_latent)
last_latent.shape
torch.Size([64, 3, 32, 32])
# Too many parameters in the end...
# Mirror check for the discriminator: a 32x32 image collapses to one logit.
img_test = torch.randn(batch_size, nc, image_size, image_size)
disc_stack = nn.Sequential(
    nn.Conv2d(nc, ndf, 4, 1, 2, bias=False),
    nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
    nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
    nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
    nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
)
last_img = disc_stack(img_test)
last_img.shape
torch.Size([64, 1, 1, 1])
# *CODE FOR PART 2.1 IN THIS CELL*
class Generator(nn.Module):
    """DCGAN generator: maps a (latent_vector_size x 1 x 1) noise tensor to a
    (nc x 32 x 32) image in [-1, 1] via stacked transposed convolutions."""

    def __init__(self):
        super(Generator, self).__init__()
        #######################################################################
        #                       ** START OF YOUR CODE **
        #######################################################################
        def up_block(c_in, c_out, stride, dropout=False):
            # ConvTranspose -> BatchNorm -> (optional Dropout) -> LeakyReLU
            layers = [nn.ConvTranspose2d(c_in, c_out, 4, stride, 1, bias=False),
                      nn.BatchNorm2d(c_out)]
            if dropout:
                layers.append(nn.Dropout(p=0.1))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.main = nn.Sequential(
            # input is Z, shape (latent_vector_size) x 1 x 1
            *up_block(latent_vector_size, ngf * 4, 1, dropout=True),  # -> (ngf*4) x 2 x 2
            *up_block(ngf * 4, ngf * 2, 2),                           # -> (ngf*2) x 4 x 4
            *up_block(ngf * 2, ngf, 2),                               # -> (ngf) x 8 x 8
            *up_block(ngf, ngf // 2, 2),                              # -> (ngf//2) x 16 x 16
            nn.ConvTranspose2d(ngf // 2, nc, 4, 2, 1, bias=False),    # -> (nc) x 32 x 32
            nn.Tanh(),
        )
        #######################################################################
        #                        ** END OF YOUR CODE **
        #######################################################################

    # You can modify the arguments of this function if needed
    def forward(self, z):
        """Generate a batch of images from the latent batch z."""
        #######################################################################
        #                       ** START OF YOUR CODE **
        #######################################################################
        return self.main(z)
        #######################################################################
        #                        ** END OF YOUR CODE **
        #######################################################################
class Discriminator(nn.Module):
    """DCGAN discriminator: maps a (nc x 32 x 32) image to a scalar
    real/fake probability via strided convolutions and a final sigmoid."""

    def __init__(self):
        super(Discriminator, self).__init__()
        #######################################################################
        #                       ** START OF YOUR CODE **
        #######################################################################
        def down_block(c_in, c_out):
            # Strided Conv -> BatchNorm -> LeakyReLU (halves spatial size).
            return [nn.Conv2d(c_in, c_out, 4, 2, 1, bias=False),
                    nn.BatchNorm2d(c_out),
                    nn.LeakyReLU(0.2, inplace=True)]

        self.main = nn.Sequential(
            # input is (nc) x 32 x 32
            *down_block(nc, ndf),           # -> (ndf) x 16 x 16
            *down_block(ndf, ndf * 2),      # -> (ndf*2) x 8 x 8
            *down_block(ndf * 2, ndf * 4),  # -> (ndf*4) x 4 x 4
            nn.Conv2d(ndf * 4, 1, 4, 1, 0, bias=False),  # -> 1 x 1 x 1
            nn.Sigmoid(),
        )
        #######################################################################
        #                        ** END OF YOUR CODE **
        #######################################################################

    # You can modify the arguments of this function if needed
    def forward(self, x):
        """Return the per-image probability that x is real."""
        #######################################################################
        #                       ** START OF YOUR CODE **
        #######################################################################
        return self.main(x)
        #######################################################################
        #                        ** END OF YOUR CODE **
        #######################################################################
You can use method weights_init to initialize the weights of the Generator and Discriminator networks. Otherwise, implement your own initialization, or do not use at all. You will not be penalized for not using initialization.
# custom weights initialization called on netG and netD
def weights_init(m):
    """DCGAN-style weight initialisation, applied via `model.apply(weights_init)`.

    Conv / ConvTranspose layers: weights ~ N(0, 0.02).
    BatchNorm layers: weights ~ N(1, 0.02), biases set to 0.

    Uses torch.nn.init instead of mutating `.data` directly — writing through
    `.data` bypasses autograd's bookkeeping and is discouraged; nn.init
    performs the same in-place initialisation under no_grad.
    """
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight, 1.0, 0.02)
        nn.init.zeros_(m.bias)
use_weights_init = True

def count_trainable(model):
    # Total element count over all parameters that require gradients.
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

model_G = Generator().to(device)
if use_weights_init:
    model_G.apply(weights_init)
params_G = count_trainable(model_G)
print("Total number of parameters in Generator is: {}".format(params_G))
print(model_G)
print('\n')

model_D = Discriminator().to(device)
if use_weights_init:
    model_D.apply(weights_init)
params_D = count_trainable(model_D)
print("Total number of parameters in Discriminator is: {}".format(params_D))
print(model_D)
print('\n')

print("Total number of parameters is: {}".format(params_G + params_D))
Total number of parameters in Generator is: 3576704
Generator(
(main): Sequential(
(0): ConvTranspose2d(100, 512, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1), bias=False)
(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): Dropout(p=0.1, inplace=False)
(3): LeakyReLU(negative_slope=0.2, inplace=True)
(4): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(5): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(6): LeakyReLU(negative_slope=0.2, inplace=True)
(7): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(8): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(9): LeakyReLU(negative_slope=0.2, inplace=True)
(10): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(11): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(12): LeakyReLU(negative_slope=0.2, inplace=True)
(13): ConvTranspose2d(64, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(14): Tanh()
)
)
Total number of parameters in Discriminator is: 2637568
Discriminator(
(main): Sequential(
(0): Conv2d(3, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): LeakyReLU(negative_slope=0.2, inplace=True)
(3): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): LeakyReLU(negative_slope=0.2, inplace=True)
(6): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(7): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(8): LeakyReLU(negative_slope=0.2, inplace=True)
(9): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), bias=False)
(10): Sigmoid()
)
)
Total number of parameters is: 6214272
# You can modify the arguments of this function if needed
def loss_function(output, label):
    """Mean binary cross-entropy between D's predictions and the targets.

    output - probabilities in [0, 1] (post-sigmoid), any shape.
    label  - targets of the same shape (may be smoothed, e.g. 0.7-1.0).
    Uses the functional form so a fresh nn.BCELoss module is not constructed
    on every call; the reduction is 'mean', identical to the original.
    """
    return F.binary_cross_entropy(output, label)
# setup optimizer
# You are free to add a scheduler or change the optimizer if you want. We chose one for you for simplicity.
#optimizerD = torch.optim.SGD(model_D.parameters(), lr=learning_rate)
# Adam for both networks, with the beta1 hyperparameter defined above.
optimizerD = torch.optim.Adam(model_D.parameters(), lr=learning_rate, betas=(beta1, 0.999))
optimizerG = torch.optim.Adam(model_G.parameters(), lr=learning_rate, betas=(beta1, 0.999))
# One fixed noise batch, reused at the end of every epoch so the saved
# sample grids are directly comparable across epochs.
fixed_noise = torch.randn(batch_size, latent_vector_size, 1, 1, device=device)
# Additional input variables should be defined here
#torch.cuda.empty_cache()
Complete the training loop below. We've defined some variables to keep track of things during training:
import tqdm

# Per-epoch average losses (for learning-curve plots).
train_losses_G = []
train_losses_D = []
# Per-batch raw losses.
errG_list = []
errD_list = []
# <- You may wish to add logging info here
for epoch in range(num_epochs):
    # <- You may wish to add logging info here
    train_loss_D = 0
    train_loss_G = 0
    with tqdm.tqdm(loader_train, unit="batch") as tepoch:
        for i, data in enumerate(tepoch):
            #######################################################################
            #                       * START OF YOUR CODE *
            #######################################################################
            ##### (1) UPDATE D network: maximize log(D(x)) + log(1 - D(G(z))) #####
            # train with real
            model_D.zero_grad()
            # Format batch (drop the labels; keep the images only).
            data = data[0]
            new_batch_size = data.shape[0]
            real_data = data.to(device)
            # Gaussian noise added to the real images to stabilise training;
            # std decays as exp(-epoch), so it vanishes after a few epochs.
            # NOTE(review): noise is added to the real images only — it is
            # commonly added to both real and fake inputs; confirm intent.
            gauss_noise = torch.normal(0, 1/np.exp(epoch), size=(new_batch_size, nc, image_size, image_size), device=device)
            real_data += gauss_noise
            # One-sided label smoothing: real targets drawn from [0.7, 1).
            rand_real_label = np.random.uniform(0.7, 1)
            real_label = torch.full((new_batch_size,), rand_real_label, dtype=torch.float, device=device)
            # Forward pass real batch through D
            real_output = model_D(real_data).view(-1)
            # Calculate loss on all-real batch
            real_errD = loss_function(real_output, real_label)
            real_errD.backward()
            D_x = real_output.mean().item()

            # train with fake
            # Generate batch of latent vectors
            latent_vectors = torch.randn(new_batch_size, latent_vector_size, 1, 1, device=device)
            # Generate fake image batch with G
            fake_data = model_G(latent_vectors)
            # Smoothed fake targets drawn from [0, 0.3).
            rand_fake_label_D = np.random.uniform(0, 0.3)
            fake_label_D = torch.full((new_batch_size,), rand_fake_label_D, dtype=torch.float, device=device)
            # Classify the fake batch with D; detach so no gradients reach G here.
            fake_output = model_D(fake_data.detach()).view(-1)
            fake_errD = loss_function(fake_output, fake_label_D)
            fake_errD.backward()
            # Compute D's loss on the all-fake batch
            D_G_z1 = fake_output.mean().item()
            # Total D loss (gradients were already accumulated by the two
            # backward calls above; this sum is for logging only).
            errD = fake_errD + real_errD
            # Update D
            optimizerD.step()

            ##### (2) UPDATE G network: maximize log(D(G(z))) #####
            model_G.zero_grad()
            # G's target is the (smoothed) real label — non-saturating loss.
            rand_real_label = np.random.uniform(0.7, 1)
            label_G = torch.full((new_batch_size,), rand_real_label, dtype=torch.float, device=device)
            # Now that we updated D, let's do another forward pass through D
            output_G = model_D(fake_data).view(-1)
            # Compute G's loss
            errG = loss_function(output_G, label_G)
            errG.backward()
            # Compute gradients for G
            D_G_z2 = output_G.mean().item()
            # Update G
            optimizerG.step()

            # Storing losses for plots
            train_loss_D += errD.item()
            train_loss_G += errG.item()
            errG_list.append(errG.item())
            errD_list.append(errD.item())
            #######################################################################
            #                         * END OF YOUR CODE *
            #######################################################################
            # Logging
            if i % 50 == 0:
                tepoch.set_description(f"Epoch {epoch}")
                tepoch.set_postfix(D_G_z=f"{D_G_z1:.3f}/{D_G_z2:.3f}", D_x=D_x,
                                   Loss_D=errD.item(), Loss_G=errG.item())
    # Save one grid of the (noisy) real images once, for reference.
    # NOTE(review): `content_path/'...'` only works if content_path is a
    # pathlib.Path — confirm it is not a plain str in this environment.
    if epoch == 0:
        save_image(denorm(real_data.cpu()).float(), content_path/'CW_GAN/real_samples.png')
    # Sample from the generator with the fixed noise so epoch-to-epoch
    # image grids are directly comparable.
    with torch.no_grad():
        fake = model_G(fixed_noise)
        #save_image(denorm(fake.cpu()).float(), content_path/'CW_GAN/fake_samples_epoch_%03d.png')
        save_image(denorm(fake.cpu()).float(), f"{content_path}/CW_GAN/fake_samples_epoch_{epoch}.png")
    # % epoch
    train_losses_D.append(train_loss_D / len(loader_train))
    train_losses_G.append(train_loss_G / len(loader_train))

# save models
# if your discriminator/generator are conditional you'll want to change the inputs here
torch.jit.save(torch.jit.trace(model_G, (fixed_noise)), content_path/'CW_GAN/GAN_G_model.pth')
torch.jit.save(torch.jit.trace(model_D, (fake)), content_path/'CW_GAN/GAN_D_model.pth')
Epoch 0: 100%|██████████| 782/782 [00:47<00:00, 16.64batch/s, D_G_z=0.266/0.061, D_x=0.66, Loss_D=0.967, Loss_G=2.72] Epoch 1: 100%|██████████| 782/782 [00:47<00:00, 16.56batch/s, D_G_z=0.338/0.335, D_x=0.648, Loss_D=1.08, Loss_G=1.09] Epoch 2: 100%|██████████| 782/782 [00:47<00:00, 16.59batch/s, D_G_z=0.189/0.168, D_x=0.588, Loss_D=1.09, Loss_G=1.7] Epoch 3: 100%|██████████| 782/782 [00:47<00:00, 16.58batch/s, D_G_z=0.268/0.316, D_x=0.571, Loss_D=1.2, Loss_G=1.21] Epoch 4: 100%|██████████| 782/782 [00:47<00:00, 16.60batch/s, D_G_z=0.327/0.109, D_x=0.718, Loss_D=1.1, Loss_G=1.95] Epoch 5: 100%|██████████| 782/782 [00:47<00:00, 16.55batch/s, D_G_z=0.300/0.365, D_x=0.66, Loss_D=1.08, Loss_G=0.992] Epoch 6: 100%|██████████| 782/782 [00:47<00:00, 16.55batch/s, D_G_z=0.304/0.336, D_x=0.643, Loss_D=1.01, Loss_G=1.13] Epoch 7: 100%|██████████| 782/782 [00:47<00:00, 16.53batch/s, D_G_z=0.501/0.123, D_x=0.819, Loss_D=1.35, Loss_G=1.86] Epoch 8: 100%|██████████| 782/782 [00:47<00:00, 16.57batch/s, D_G_z=0.349/0.591, D_x=0.68, Loss_D=1.08, Loss_G=0.674] Epoch 9: 100%|██████████| 782/782 [00:47<00:00, 16.64batch/s, D_G_z=0.381/0.268, D_x=0.801, Loss_D=1.15, Loss_G=1.2] Epoch 10: 100%|██████████| 782/782 [00:47<00:00, 16.57batch/s, D_G_z=0.346/0.301, D_x=0.602, Loss_D=1.18, Loss_G=1.15] Epoch 11: 100%|██████████| 782/782 [00:47<00:00, 16.57batch/s, D_G_z=0.265/0.370, D_x=0.677, Loss_D=1.15, Loss_G=0.914] Epoch 12: 100%|██████████| 782/782 [00:47<00:00, 16.59batch/s, D_G_z=0.384/0.132, D_x=0.831, Loss_D=0.847, Loss_G=2.11] Epoch 13: 100%|██████████| 782/782 [00:47<00:00, 16.64batch/s, D_G_z=0.323/0.209, D_x=0.705, Loss_D=1.09, Loss_G=1.44] Epoch 14: 100%|██████████| 782/782 [00:47<00:00, 16.58batch/s, D_G_z=0.417/0.228, D_x=0.783, Loss_D=1.02, Loss_G=1.46] Epoch 15: 100%|██████████| 782/782 [00:47<00:00, 16.54batch/s, D_G_z=0.593/0.151, D_x=0.867, Loss_D=1.32, Loss_G=1.91] Epoch 16: 100%|██████████| 782/782 [00:47<00:00, 16.61batch/s, D_G_z=0.190/0.324, D_x=0.642, Loss_D=1.16, 
Loss_G=1.15] Epoch 17: 100%|██████████| 782/782 [00:47<00:00, 16.56batch/s, D_G_z=0.311/0.240, D_x=0.754, Loss_D=1.27, Loss_G=1.36] Epoch 18: 100%|██████████| 782/782 [00:47<00:00, 16.61batch/s, D_G_z=0.247/0.204, D_x=0.699, Loss_D=1.09, Loss_G=1.62] Epoch 19: 100%|██████████| 782/782 [00:47<00:00, 16.62batch/s, D_G_z=0.265/0.402, D_x=0.749, Loss_D=0.866, Loss_G=0.907] Epoch 20: 100%|██████████| 782/782 [00:46<00:00, 16.64batch/s, D_G_z=0.225/0.410, D_x=0.62, Loss_D=1.12, Loss_G=0.901] Epoch 21: 100%|██████████| 782/782 [00:47<00:00, 16.57batch/s, D_G_z=0.383/0.364, D_x=0.785, Loss_D=1.02, Loss_G=0.933] Epoch 22: 100%|██████████| 782/782 [00:47<00:00, 16.36batch/s, D_G_z=0.408/0.059, D_x=0.865, Loss_D=1.27, Loss_G=2.33] Epoch 23: 100%|██████████| 782/782 [00:47<00:00, 16.58batch/s, D_G_z=0.163/0.170, D_x=0.526, Loss_D=0.939, Loss_G=1.74] Epoch 24: 100%|██████████| 782/782 [00:47<00:00, 16.57batch/s, D_G_z=0.220/0.180, D_x=0.663, Loss_D=0.96, Loss_G=1.37] Epoch 25: 100%|██████████| 782/782 [00:47<00:00, 16.62batch/s, D_G_z=0.215/0.198, D_x=0.669, Loss_D=1.06, Loss_G=1.43] Epoch 26: 100%|██████████| 782/782 [00:47<00:00, 16.61batch/s, D_G_z=0.353/0.464, D_x=0.845, Loss_D=0.834, Loss_G=0.777] Epoch 27: 100%|██████████| 782/782 [00:47<00:00, 16.61batch/s, D_G_z=0.215/0.242, D_x=0.684, Loss_D=0.91, Loss_G=1.45] Epoch 28: 100%|██████████| 782/782 [00:47<00:00, 16.57batch/s, D_G_z=0.237/0.307, D_x=0.732, Loss_D=0.914, Loss_G=1.12] Epoch 29: 100%|██████████| 782/782 [00:47<00:00, 16.58batch/s, D_G_z=0.370/0.303, D_x=0.836, Loss_D=1.08, Loss_G=1.09] Epoch 30: 100%|██████████| 782/782 [00:47<00:00, 16.59batch/s, D_G_z=0.175/0.077, D_x=0.73, Loss_D=0.874, Loss_G=2.38] Epoch 31: 100%|██████████| 782/782 [00:47<00:00, 16.59batch/s, D_G_z=0.370/0.084, D_x=0.841, Loss_D=0.983, Loss_G=2.15] Epoch 32: 100%|██████████| 782/782 [00:47<00:00, 16.61batch/s, D_G_z=0.323/0.225, D_x=0.786, Loss_D=0.937, Loss_G=1.26] Epoch 33: 100%|██████████| 782/782 [00:47<00:00, 16.60batch/s, 
D_G_z=0.169/0.248, D_x=0.646, Loss_D=1.29, Loss_G=1.21] Epoch 34: 100%|██████████| 782/782 [00:47<00:00, 16.58batch/s, D_G_z=0.349/0.636, D_x=0.798, Loss_D=0.887, Loss_G=0.507] Epoch 35: 100%|██████████| 782/782 [00:47<00:00, 16.62batch/s, D_G_z=0.363/0.105, D_x=0.894, Loss_D=0.991, Loss_G=1.8] Epoch 36: 100%|██████████| 782/782 [00:47<00:00, 16.54batch/s, D_G_z=0.363/0.174, D_x=0.772, Loss_D=0.897, Loss_G=1.43] Epoch 37: 100%|██████████| 782/782 [00:47<00:00, 16.54batch/s, D_G_z=0.328/0.222, D_x=0.77, Loss_D=1.24, Loss_G=1.49] Epoch 38: 100%|██████████| 782/782 [00:47<00:00, 16.61batch/s, D_G_z=0.180/0.109, D_x=0.769, Loss_D=1.01, Loss_G=1.88] Epoch 39: 100%|██████████| 782/782 [00:47<00:00, 16.59batch/s, D_G_z=0.139/0.377, D_x=0.649, Loss_D=1.18, Loss_G=0.876] Epoch 40: 100%|██████████| 782/782 [00:46<00:00, 16.64batch/s, D_G_z=0.177/0.164, D_x=0.591, Loss_D=1.04, Loss_G=1.88] Epoch 41: 100%|██████████| 782/782 [00:46<00:00, 16.64batch/s, D_G_z=0.225/0.177, D_x=0.783, Loss_D=0.927, Loss_G=1.4] Epoch 42: 100%|██████████| 782/782 [00:47<00:00, 16.30batch/s, D_G_z=0.386/0.213, D_x=0.802, Loss_D=1.17, Loss_G=1.25] Epoch 43: 100%|██████████| 782/782 [00:48<00:00, 15.98batch/s, D_G_z=0.315/0.196, D_x=0.8, Loss_D=1.24, Loss_G=1.43] Epoch 44: 100%|██████████| 782/782 [00:49<00:00, 15.95batch/s, D_G_z=0.322/0.166, D_x=0.795, Loss_D=1.03, Loss_G=1.71] Epoch 45: 100%|██████████| 782/782 [00:48<00:00, 16.00batch/s, D_G_z=0.359/0.095, D_x=0.844, Loss_D=0.945, Loss_G=2.16] Epoch 46: 100%|██████████| 782/782 [00:48<00:00, 16.07batch/s, D_G_z=0.220/0.301, D_x=0.724, Loss_D=1.26, Loss_G=1.21] Epoch 47: 100%|██████████| 782/782 [00:48<00:00, 15.98batch/s, D_G_z=0.206/0.391, D_x=0.645, Loss_D=1.16, Loss_G=0.96] Epoch 48: 100%|██████████| 782/782 [00:49<00:00, 15.93batch/s, D_G_z=0.177/0.178, D_x=0.772, Loss_D=1.13, Loss_G=1.7] Epoch 49: 100%|██████████| 782/782 [00:48<00:00, 16.09batch/s, D_G_z=0.290/0.383, D_x=0.803, Loss_D=1.25, Loss_G=0.924] 
/opt/conda/lib/python3.8/site-packages/torch/jit/_trace.py:983: TracerWarning: Output nr 1. of the traced function does not match the corresponding output of the Python function. Detailed error: Tensor-likes are not close! Mismatched elements: 196251 / 196608 (99.8%) Greatest absolute difference: 0.9063645005226135 at index (38, 1, 15, 15) (up to 1e-05 allowed) Greatest relative difference: 92598.15496385845 at index (12, 2, 10, 21) (up to 1e-05 allowed) _check_trace(
This part is fairly open-ended, but not worth too much so do not go crazy. The table below shows examples of what are considered good samples. Level 3 and above will get you 10/10 points, level 2 will roughly get you 5/10 points and level 1 and below will get you 0/10 points.
|
|
|
|
# Sample 100 fresh latent vectors and visualize the generator's output grid.
input_noise = torch.randn(100, latent_vector_size, 1, 1, device=device)
with torch.no_grad():
    # visualize the generated images
    generated = model_G(input_noise).cpu()
# Fix: `range=None` is deprecated (and later removed) in torchvision's
# make_grid — drop it and rely on the default range inference.
generated = make_grid(denorm(generated)[:100], nrow=10, padding=2, normalize=True,
                      scale_each=False, pad_value=0)
plt.figure(figsize=(15, 15))
save_image(generated, content_path/'CW_GAN/Teaching_final.png')
# NOTE(review): the comment below is left over from the teaching notebook;
# this generator takes noise only, so the samples are NOT class conditional.
show(generated)  # note these are now class conditional images columns rep classes 1-10
# Show a grid of real test images for side-by-side comparison with the fakes.
it = iter(loader_test)
# NOTE(review): shuffle settings aside, this is the FIRST batch of the test
# set, not the last as the original comment claimed.
sample_inputs, _ = next(it)
fixed_input = sample_inputs[0:64, :, :, :]
# visualize original images of the test set for comparison
# Fix: `range=None` is deprecated (and later removed) in torchvision's
# make_grid — drop it and rely on the default.
img = make_grid(denorm(fixed_input), nrow=8, padding=2, normalize=False,
                scale_each=False, pad_value=0)
plt.figure(figsize=(15, 15))
show(img)
Discuss the process you took to arrive at your final architecture. This should include:
Your Answer
Your task:
Plot the losses curves for the discriminator $D$ and the generator $G$ as the training progresses and explain whether the produced curves are theoretically sensible and why this is (or not) the case (x-axis: epochs, y-axis: loss).
Make sure that the version of the notebook you deliver includes these results.
# ANSWER FOR PART 2.2 IN THIS CELL*
# Epoch-averaged loss curves for the generator and discriminator.
plt.figure(figsize=(20, 10))
for curve, curve_label in ((train_losses_G, "Generator"),
                           (train_losses_D, "Discriminator")):
    plt.plot(curve, label=curve_label)
plt.xlabel('Epoch number', fontsize=15)
plt.ylabel('Loss value', fontsize=15)
plt.title('Loss curve epoch wise', fontsize=15)
plt.legend(loc="upper right", prop={"size": 15})
<matplotlib.legend.Legend at 0x7fc099b9ae80>
# ANSWER FOR PART 2.2 IN THIS CELL*
# Per-batch loss histories — finer-grained view of the same training run.
plt.figure(figsize=(20, 10))
for curve, curve_label in ((errG_list, "Generator"),
                           (errD_list, "Discriminator")):
    plt.plot(curve, label=curve_label)
plt.xlabel('Batch number', fontsize=15)
plt.ylabel('Loss value', fontsize=15)
plt.title('Loss curve batch wise', fontsize=15)
plt.legend(loc="upper right", prop={"size": 15})
<matplotlib.legend.Legend at 0x7fc099b25880>
Do your loss curves look sensible? What would you expect to see and why?
YOUR ANSWER
Your task:
Describe what causes the phenomenon of mode collapse and how it may manifest in the samples drawn from a GAN.
Based on the images created by your generator using the fixed_noise vector during training, did you notice any mode collapse? What may this behaviour be attributed to, and what did you try in order to eliminate or reduce it?
YOUR ANSWER
Mode collapse occurs when the generator is only able to produce one or a few distinct outputs, instead of covering the full diversity of the data distribution.
At the very beginning I was stuck in mode collapse for a few days: the generator kept producing the same image — a set of random pixels spread over the whole square. The loss oscillated and the model never learned actual representations. I then succeeded in escaping mode collapse once I added label smoothing and Gaussian noise to the discriminator's inputs.
TAs will run this cell to ensure that your results are reproducible, and that your models have been defined suitably.
Please provide the input and output transformations required to make your VAE and GANs work. If your GAN generator requires more than just noise as input, also specify this below (there are two marked cells for you to inspect)
# If you want to run these tests yourself, change directory:
# %cd '.../dl_cw2/'
ta_data_path = "../data" # You can change this to = data_path when testing
!pip install -q torch torchvision
WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv
# Do not remove anything here
import os
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, sampler
from torchvision import datasets, transforms
from torchvision.utils import save_image, make_grid
import torch.nn.functional as F
import matplotlib.pyplot as plt
# PEP 8 (E731): assign a def, not a lambda. Behavior is unchanged: move a
# CHW tensor to host memory and display it as an HWC image via imshow.
def show(img):
    """Display a (C, H, W) tensor as an image; returns the AxesImage."""
    return plt.imshow(np.transpose(img.cpu().numpy(), (1, 2, 0)))

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Do not change this cell!
# Fixed RNG seed + deterministic cuDNN so the TA run reproduces the samples.
torch.backends.cudnn.deterministic = True
torch.manual_seed(0)
<torch._C.Generator at 0x7fc170234bd0>
############# CHANGE THESE (COPY AND PASTE FROM YOUR OWN CODE) #############
# VAE pipeline: MNIST digits are used as-is — only PIL -> tensor conversion
# (scales pixels into [0, 1]); no normalization is applied.
vae_transform = transforms.Compose([
transforms.ToTensor(),
])
# CIFAR-10 per-channel statistics — must match the Normalize used at
# GAN training time.
mean = torch.Tensor([0.4914, 0.4822, 0.4465])
std = torch.Tensor([0.247, 0.243, 0.261])

def unnormalize(x):
    """Invert transforms.Normalize(mean, std): x * std + mean, per channel.

    Pure-torch equivalent of the previous
    transforms.Normalize((-mean/std).tolist(), (1/std).tolist()) module;
    still called the same way, and also broadcasts over an optional leading
    batch dimension.
    """
    return x * std.view(-1, 1, 1) + mean.view(-1, 1, 1)

def denorm(x, channels=None, w=None, h=None, resize=False):
    """Map a normalized image tensor back to pixel space.

    Args:
        x: (C, H, W) or (B, C, H, W) tensor normalized with (mean, std).
        channels, w, h: target shape, required when resize=True.
        resize: if True, reshape the result to (batch, channels, w, h).

    Raises:
        ValueError: if resize is requested without channels/w/h.
        (Fix: the original printed a message and then crashed inside
        view(None, ...); failing loudly is clearer.)
    """
    x = unnormalize(x)
    if resize:
        if channels is None or w is None or h is None:
            raise ValueError(
                'Number of channels, width and height must be provided for resize.')
        x = x.view(x.size(0), channels, w, h)
    return x

def vae_denorm(x):
    """VAE pipeline uses ToTensor() only, so outputs are already in [0, 1]."""
    return x

def gan_denorm(x):
    """GAN samples were trained on normalized data; invert that here."""
    return denorm(x)
# Latent (noise) vector length fed to the generator — must match training.
gan_latent_size = 100
# If your generator requires something other than noise as input, please specify
# two cells down from here
# Load VAE Dataset — MNIST test split only; the TA cell reconstructs
# held-out digits.
test_dat = datasets.MNIST(ta_data_path, train=False, transform=vae_transform,
download=True)
# shuffle=False keeps the reconstruction grid reproducible across runs
vae_loader_test = DataLoader(test_dat, batch_size=32, shuffle=False)
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ../data/MNIST/raw/train-images-idx3-ubyte.gz
Extracting ../data/MNIST/raw/train-images-idx3-ubyte.gz to ../data/MNIST/raw Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ../data/MNIST/raw/train-labels-idx1-ubyte.gz
Extracting ../data/MNIST/raw/train-labels-idx1-ubyte.gz to ../data/MNIST/raw Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ../data/MNIST/raw/t10k-images-idx3-ubyte.gz
Extracting ../data/MNIST/raw/t10k-images-idx3-ubyte.gz to ../data/MNIST/raw Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ../data/MNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting ../data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ../data/MNIST/raw
############# MODIFY IF NEEDED #############
# One batch of MNIST test digits for the VAE reconstruction check
vae_input, _ = next(iter(vae_loader_test))
# If your generator is conditional, then please modify this input suitably
input_noise = torch.randn(100, gan_latent_size, 1, 1, device=device)
gan_input = [input_noise] # In case you want to provide a tuple, we wrap ours
# VAE Tests
# TAs will change these paths as you will have provided the model files manually
"""To TAs, you should have been creating a folder with the student uid
And the .ipynb + models in the root. Then that path is './VAE_model.pth' etc.
"""
# NOTE(review): this also rebinds model_G to the VAE; model_G is re-assigned
# in the GAN cell below, but the double assignment looks like a typo — confirm.
vae = model_G = torch.jit.load('./CW_VAE/VAE_model.pth')
vae.eval()
# Check if VAE is convolutional
def recurse_cnn_check(parent, flag):
    """Return True iff any module in `parent`'s tree owns a 4-D 'weight'
    parameter (i.e. a conv kernel), short-circuiting once one is found.

    Args:
        parent: root module; works for nn.Module trees (and, presumably,
            jit-loaded module trees, as in the original — confirm).
        flag: carried accumulator from the original signature; pass False
            at the top-level call.
    """
    if flag:
        return True
    # Fix: inspect this module's OWN parameters too. The original only
    # looked at leaf modules, so a 4-D weight held by a module that also
    # has children was missed. Also guard against a registered-but-None
    # 'weight' entry in _parameters.
    params = parent._parameters
    if 'weight' in params.keys():
        weight = params['weight']
        if weight is not None and weight.ndim == 4:
            return True
    # Recurse into children, stopping at the first positive subtree.
    for child in parent.children():
        if recurse_cnn_check(child, False):
            return True
    return False
# Report whether the submitted VAE contains any conv layer
has_cnn = recurse_cnn_check(vae, False)
print("Used CNN" if has_cnn else "Didn't Use CNN")
# Grid of the original input digits
# NOTE(review): make_grid's `range=` kwarg is deprecated (warning in the
# output below); newer torchvision uses `value_range=` — update when possible.
vae_in = make_grid(vae_denorm(vae_input), nrow=8, padding=2, normalize=False,
range=None, scale_each=False, pad_value=0)
plt.figure()
plt.axis('off')
show(vae_in)
# Reconstructions: the VAE presumably returns (reconstruction, mu, logvar),
# hence the [0] — confirm against the VAE's forward().
vae_test = vae(vae_input.to(device))[0].detach()
vae_reco = make_grid(vae_denorm(vae_test), nrow=8, padding=2, normalize=False,
range=None, scale_each=False, pad_value=0)
plt.figure()
plt.axis('off')
show(vae_reco)
Used CNN
/opt/conda/lib/python3.8/site-packages/torchvision/utils.py:50: UserWarning: range will be deprecated, please use value_range instead. warnings.warn(warning)
<matplotlib.image.AxesImage at 0x7fbf957282e0>
# GAN Tests
# Reload the traced generator/discriminator saved at the end of training
model_G = torch.jit.load('./CW_GAN/GAN_G_model.pth')
model_D = torch.jit.load('./CW_GAN/GAN_D_model.pth')
[model.eval() for model in (model_G, model_D)]
# Check that GAN doesn't have too many parameters (combined budget: 25M)
num_param = sum(p.numel() for p in [*model_G.parameters(),*model_D.parameters()])
print(f"Number of Parameters is {num_param} which is", "ok" if num_param<25E+6 else "not ok")
# visualize the generated images from the TA-provided noise input
generated = model_G(*gan_input).cpu()
# NOTE(review): `range=` is deprecated in make_grid (see warning above);
# newer torchvision uses `value_range=` — update when possible.
generated = make_grid(gan_denorm(generated)[:100].detach(), nrow=10, padding=2, normalize=True,
range=None, scale_each=False, pad_value=0)
plt.figure(figsize=(15,15))
plt.axis('off')
show(generated)
Number of Parameters is 6214272 which is ok
<matplotlib.image.AxesImage at 0x7fc04cd12550>